import cv2
import os
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import io
import torch
import re
import warnings
warnings.filterwarnings("ignore")
import time
from modelscope import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import json 
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.build_sam import build_sam2

def new_generate_augment_imgs(image_path):
    """
    Write three views of an image to ``augment/<image stem>/`` and return their paths.

    The views are, in order:
      1. ``original.jpg``        - the image as read by OpenCV,
      2. ``horizontal_flip.jpg`` - a left-right mirror,
      3. ``resized.jpg``         - a copy downscaled to (width // 2, height // 2).

    Args:
        image_path (str): Path of the source image.

    Returns:
        list[str] | None: The three output paths in the order above, or None
        when OpenCV cannot read ``image_path`` (a message is printed).
    """
    image = cv2.imread(image_path)
    if image is None:
        print("图像加载失败，请检查文件路径是否正确。")
        return None

    stem = os.path.splitext(os.path.basename(image_path))[0]
    # Other helpers in this module rebuild this exact directory from the image stem.
    output_dir = f"augment/{stem}"
    os.makedirs(output_dir, exist_ok=True)

    height, width = image.shape[:2]
    # cv2.resize rejects a zero-sized target, which a 1-pixel-wide/tall input
    # would otherwise produce; never go below 1 pixel.
    half_size = (max(1, width // 2), max(1, height // 2))

    views = [
        ("original.jpg", image),
        # flipCode=1 mirrors around the vertical axis (left <-> right).
        ("horizontal_flip.jpg", cv2.flip(image, 1)),
        ("resized.jpg", cv2.resize(image, half_size, interpolation=cv2.INTER_AREA)),
    ]

    img_path_list = []
    for filename, view in views:
        out_path = os.path.join(output_dir, filename)
        # imwrite signals failure through its return value rather than raising.
        if not cv2.imwrite(out_path, view):
            print(f"Failed to write {out_path}")
        img_path_list.append(out_path)

    return img_path_list

def build_generation_prompt(query):
    """Return the grounding prompt asking the VLM to locate ``query`` and reply with JSON bboxes."""
    return f'Locate "{query}", report the bboxes coordinates in JSON format.'


def generate_response(image_path, prompt, model, processor, max_new_tokens=128, device="cuda"):
    """
    Run one single-image, single-turn greedy generation with a Qwen-VL style model.

    Args:
        image_path (str): Image given to the model; forwarded as the "image"
            entry of the chat message and resolved by ``process_vision_info``.
        prompt (str): User text that follows the image.
        model: Generation model exposing ``generate(**inputs, ...)``.
        processor: Matching processor providing ``apply_chat_template``,
            ``__call__`` and ``batch_decode``.
        max_new_tokens (int): Generation budget (128 was the previous fixed value).
        device (str | torch.device): Device the tokenized inputs are moved to;
            it must match where the model expects its inputs ("cuda" by default,
            as before).

    Returns:
        str: The decoded reply with the prompt tokens stripped.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_path},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)

    # Greedy decoding keeps the output deterministic for the same image/prompt.
    generated_ids = model.generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False, num_beams=1
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]

def extract_json_from_response(response_text):
    """
    Parse the first ```json fenced block in a model reply.

    Args:
        response_text (str): Raw model output.

    Returns:
        dict | list | None: The decoded JSON value. None when the reply has no
        ```json ... ``` block, or when its contents are not valid JSON (the
        offending text is printed in that case).
    """
    found = re.search(r"```json\s*(.*?)\s*```", response_text, re.DOTALL)
    if found is None:
        return None

    payload = found.group(1).strip()
    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        print(f"Error parsing JSON: {payload}")
        return None

def extract_coords_from_response(response_text):
    """
    Extract the first bounding box from a ```json fenced block in a model reply.

    The JSON may be a list of objects or a single object; each object is
    searched for a ``bbox`` or ``bbox_2d`` key holding four numbers.

    Args:
        response_text (str): Raw model output.

    Returns:
        list[int] | None: ``[x1, y1, x2, y2]`` (values truncated with ``int``)
        of the first usable box, or None when there is no fenced block, the
        JSON is invalid (the text is printed), or no object carries a
        four-number box.
    """
    match = re.search(r"```json\s*(.*?)\s*```", response_text, re.DOTALL)
    if not match:
        return None

    json_str = match.group(1).strip()
    try:
        json_data = json.loads(json_str)
    except json.JSONDecodeError:
        print(f"Error parsing JSON: {json_str}")
        return None

    # A single detection is sometimes emitted as a bare object instead of a list.
    if isinstance(json_data, dict):
        json_data = [json_data]
    if not json_data or not isinstance(json_data, list):
        return None

    for item in json_data:
        if not isinstance(item, dict):
            continue
        for key in ('bbox', 'bbox_2d'):
            coords = item.get(key)
            if isinstance(coords, (list, tuple)) and len(coords) == 4:
                try:
                    return [int(c) for c in coords]
                except (TypeError, ValueError):
                    # Non-numeric entry; try the next key / object.
                    continue
    # Previously this case raised IndexError on an empty result list.
    return None

def column_means(data):
    """
    Return the per-column arithmetic mean of a list of numeric rows.

    The column count is taken from the first row. A row shorter than the first
    only contributes to the columns it has; a row longer than the first raises
    IndexError. Returns [] when ``data`` is empty or its first row is empty.
    """
    if not data or not data[0]:
        return []

    width = len(data[0])
    totals = [0] * width
    seen = [0] * width

    for row in data:
        for col, value in enumerate(row):
            totals[col] += value
            seen[col] += 1

    return [total / n for total, n in zip(totals, seen)]

def process_coords(coords_list, width, height):
    """
    Map per-view boxes back to original-image pixels and append their mean.

    Args:
        coords_list: Sequence or dict indexable by 0, 1 and 2. Each entry holds
            the [x_min, y_min, x_max, y_max] box predicted on the original,
            horizontally flipped and half-size view respectively, or a falsy
            value when that view produced no box.
        width (int): Original image width in pixels.
        height (int): Original image height in pixels.

    Returns:
        list: The available mapped boxes (original, flip, resized order; missing
        views skipped), each clamped to the image, followed by their
        per-coordinate mean. Values derived from the flip/resized views and
        the mean are rounded to 2 decimals. Returns [] when no view produced a
        box (previously a meaningless empty "mean box" was appended).

    NOTE(review): assumes each box is in the pixel space of the view image it
    was predicted on; if the VLM reports coordinates in its own resized input
    space an extra rescale is needed - confirm.
    """
    def _clip(box):
        # Clamp every coordinate into the image: x into [0, width], y into [0, height].
        x1, y1, x2, y2 = box
        return [min(max(x1, 0), width), min(max(y1, 0), height),
                min(max(x2, 0), width), min(max(y2, 0), height)]

    def _round2(box):
        return [float(f"{v:.2f}") for v in box]

    all_coords = []

    original_coord = coords_list[0]
    if original_coord:
        all_coords.append(_clip(original_coord))

    horizontal_coord = coords_list[1]
    if horizontal_coord:
        # Mirror edge coordinates (x -> width - x); min and max swap roles.
        # The old width-1-x mapping is for pixel indices and shifted boxes 1px left.
        mirrored = [width - horizontal_coord[2], horizontal_coord[1],
                    width - horizontal_coord[0], horizontal_coord[3]]
        all_coords.append(_clip(_round2(mirrored)))

    resized_coord = coords_list[2]
    if resized_coord:
        # The half-size view is written at (width // 2, height // 2), so use the
        # exact ratio; it equals 2 for even dimensions.
        sx = width / max(1, width // 2)
        sy = height / max(1, height // 2)
        scaled = [resized_coord[0] * sx, resized_coord[1] * sy,
                  resized_coord[2] * sx, resized_coord[3] * sy]
        all_coords.append(_clip(_round2(scaled)))

    if not all_coords:
        return []

    mean_coord = [sum(column) / len(all_coords) for column in zip(*all_coords)]
    all_coords.append(_round2(mean_coord))
    return all_coords

def iou(bbox1, bbox2):
    """
    Return the intersection-over-union of two axis-aligned boxes.

    Both boxes are [x_min, y_min, x_max, y_max] in the same units (pixel or
    normalized; the ratio does not depend on the unit). Returns 0.0 when the
    union area is zero.
    """
    ax0, ay0, ax1, ay1 = bbox1
    bx0, by0, bx1, by1 = bbox2

    # Overlap extents; negative extents mean the boxes are disjoint on that axis.
    overlap_w = max(0, min(ax1, bx1) - max(ax0, bx0))
    overlap_h = max(0, min(ay1, by1) - max(ay0, by0))
    overlap = overlap_w * overlap_h

    union = (ax1 - ax0) * (ay1 - ay0) + (bx1 - bx0) * (by1 - by0) - overlap
    if union == 0:
        return 0.0
    return overlap / union

def remove_duplicate_bboxes(bboxes, iou_threshold=0.8):
    """
    Greedily drop near-duplicate boxes.

    Pairs are visited in index order (i < j), skipping boxes already dropped.
    When iou(bboxes[i], bboxes[j]) >= ``iou_threshold``, the smaller box is
    dropped; on equal area the earlier box (i) is dropped. Surviving boxes keep
    their original order.
    """
    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    dropped = set()
    for i, first in enumerate(bboxes):
        if i in dropped:
            continue
        for j in range(i + 1, len(bboxes)):
            if j in dropped:
                continue
            if iou(first, bboxes[j]) < iou_threshold:
                continue
            if _area(first) > _area(bboxes[j]):
                dropped.add(j)
            else:
                # Box i is gone; nothing further to compare it against.
                dropped.add(i)
                break

    return [box for k, box in enumerate(bboxes) if k not in dropped]

def draw_bboxes_on_image(image_path, bboxes, colors=None):
    """
    Draw numbered boxes on an image and save it as
    ``augment/<image stem>/with_multi_boxes.jpg``.

    Box k (1-based) is outlined in ``colors[(k - 1) % len(colors)]``. It is
    labelled with its number in white on a filled tag at its top-left corner.

    Args:
        image_path (str): Source image path; its stem selects the output directory.
        bboxes (list): Boxes as [x_min, y_min, x_max, y_max] in pixel coordinates.
            Swapped corners are tolerated.
        colors (list of tuple, optional): RGB outline colours. Defaults to red,
            green, blue, yellow, magenta, cyan.

    Returns:
        str: Path of the saved image.
    """
    if colors is None:
        colors = [
            (255, 0, 0),     # red
            (0, 255, 0),     # green
            (0, 0, 255),     # blue
            (255, 255, 0),   # yellow
            (255, 0, 255),   # magenta
            (0, 255, 255),   # cyan
        ]

    # Load fully and release the file handle; RGB so the result can be saved as JPEG.
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    draw = ImageDraw.Draw(image)

    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except (OSError, ImportError):
        # Arial is often missing (e.g. on Linux) or FreeType is unavailable;
        # fall back to Pillow's built-in bitmap font.
        font = ImageFont.load_default()

    for idx, bbox in enumerate(bboxes):
        x_a, y_a, x_b, y_b = bbox
        # Recent Pillow versions reject rectangles with inverted corners.
        left, right = min(x_a, x_b), max(x_a, x_b)
        top, bottom = min(y_a, y_b), max(y_a, y_b)
        color = colors[idx % len(colors)]

        draw.rectangle([left, top, right, bottom], outline=color, width=3)

        text = str(idx + 1)
        text_bbox = draw.textbbox((0, 0), text, font=font)
        text_w = text_bbox[2] - text_bbox[0]
        text_h = text_bbox[3] - text_bbox[1]
        # Put the tag just above the box when there is room, otherwise at its top edge.
        text_x = int(left) + 5
        text_y = int(top) - text_h if top > text_h else int(top)

        draw.rectangle([text_x, text_y, text_x + text_w, text_y + text_h], fill=color)
        draw.text((text_x, text_y), text, fill="white", font=font)

    stem = os.path.splitext(os.path.basename(image_path))[0]
    output_dir = f"augment/{stem}"
    os.makedirs(output_dir, exist_ok=True)
    output_path = output_dir + '/with_multi_boxes.jpg'
    image.save(output_path)
    return output_path

def build_select_prompt(query, coords):
    """
    Build the prompt asking the VLM to pick the best of several numbered boxes.

    Box k is listed as "Bbox k (<colour>)". The colour names follow the default
    palette of draw_bboxes_on_image (red, green, blue, yellow, magenta, cyan) and
    cycle when there are more boxes than colours. Previously a 5th box raised
    IndexError.

    Args:
        query (str): Description of the target object.
        coords (list): Candidate boxes, [x_min, y_min, x_max, y_max] each.

    Returns:
        str: The prompt text, ending with the expected answer format.
    """
    colors = ['red', 'green', 'blue', 'yellow', 'magenta', 'cyan']

    prompt = f'Please analyze the image provided below and determine which of the bounding boxes better captures the "{query}".\n'

    prompt += """
Your task is to:
1. Identify which bounding box more accurately includes the entire target object.
2. Provide a brief explanation for your choice.

Note: The format of the bounding box is [x_min, y_min, x_max, y_max], representing the top-left and bottom-right coordinates.

The coordinates for each bounding box are as follows:
"""
    for idx, box in enumerate(coords):
        prompt += f'- **Bbox {idx + 1} ({colors[idx % len(colors)]})**: {box}\n'

    prompt += """
Return your answer in the following format:

Best Box: <Box Number>
Reasoning: <Explanation>"""
    return prompt

def extract_best_box_number(text):
    """
    Extract the chosen box number from a selection reply.

    Prefers the number following "Best Box" (case-insensitive). Up to 20
    non-digit characters are allowed in between, e.g. "**Best Box:** Bbox 3".
    Otherwise falls back to the first integer anywhere in the text. The old
    behaviour picked the first integer even when an earlier number, such as
    "the 3 boxes", preceded the answer.

    Returns:
        str | None: The digits as a string, or None if the text has no digits.
    """
    anchored = re.search(r'Best\s*Box\s*:?\s*\D{0,20}?(\d+)', text, re.IGNORECASE)
    if anchored:
        return anchored.group(1)

    match = re.search(r'\d+', text)
    if match:
        return match.group(0)
    return None

def draw_bbox_on_image(image_path, normalized_bbox_coords, image_name):
    """
    Draw one red box on an image and save it as
    ``augment/<image stem>/<image_name>.jpg``.

    Args:
        image_path (str): Source image path; its stem selects the output directory.
        normalized_bbox_coords: [x_min, y_min, x_max, y_max] in pixel coordinates.
            The name is historical: no scaling is applied. Swapped corners are
            tolerated.
        image_name (str): Output file name without extension.

    Returns:
        str: Path of the saved image.
    """
    # Load fully and release the file handle. JPEG cannot store RGBA/P images,
    # so convert to RGB, as draw_bboxes_on_image does.
    with Image.open(image_path) as src:
        img = src.convert("RGB")

    x_a, y_a, x_b, y_b = normalized_bbox_coords
    # Recent Pillow versions reject rectangles with inverted corners.
    box = [min(x_a, x_b), min(y_a, y_b), max(x_a, x_b), max(y_a, y_b)]

    draw = ImageDraw.Draw(img)
    draw.rectangle(box, outline="red", width=3)

    stem = os.path.splitext(os.path.basename(image_path))[0]
    output_dir = f"augment/{stem}"
    os.makedirs(output_dir, exist_ok=True)
    output_path = output_dir + f'/{image_name}.jpg'

    img.save(output_path)
    return output_path

def build_optimize_prompt(query, current_box):
    """
    Build the prompt asking the VLM to check and refine a single box.

    Args:
        query (str): Description of the target object.
        current_box: Indexable [x_min, y_min, x_max, y_max]; each of the first
            four values is shown with two decimals.

    Returns:
        str: The prompt text, ending with the "Current Box / Optimized Box /
        Reasoning" answer format.
    """
    edge_labels = (("x_min", "left"), ("y_min", "top"), ("x_max", "right"), ("y_max", "bottom"))

    parts = [
        f'Please analyze the image provided below and evaluate whether the current bounding box accurately captures the "{query}".\n',
        f"The current bounding box coordinates are: {current_box}, where:\n",
    ]
    for idx, (name, side) in enumerate(edge_labels):
        parts.append(f"- `{name}` = {current_box[idx]:.2f} ({side} edge)\n")

    parts.append("""

Your task is to:
1. Assess whether the current bounding box adequately includes the entire target object.
2. If the current box does not perfectly capture the target object or leaves unnecessary margins, suggest an optimized bounding box with improved coordinates.

Note: The current bounding box may not be accurate. Please carefully analyze the image and improve the coordinates if necessary.

Return your response in the following format:

Current Box: [x_min, y_min, x_max, y_max]
Optimized Box: [x_min_optimized, y_min_optimized, x_max_optimized, y_max_optimized]
Reasoning: <Explanation of why the optimization was made>
""")
    return "".join(parts)

def extract_optimized_box(text):
    """
    Extract the box following "Optimized Box:" in a refinement reply.

    Text such as markdown emphasis may sit between the label and the opening
    bracket, e.g. "**Optimized Box:** [..]". Signed integers and decimals are
    parsed; the previous pattern dropped the sign of negative integers.

    Returns:
        list[float] | None: The first four numbers inside the brackets. None
        when the label is missing or fewer than four numbers are present
        (e.g. the template placeholder was echoed back).
    """
    pattern = r'Optimized\s+Box\s*:?[^\[\n]*\[([^\]]+)\]'
    match = re.search(pattern, text, re.IGNORECASE)
    if not match:
        print("Could not find 'Optimized Box' in response.")
        return None

    coords_str = match.group(1)
    numbers = re.findall(r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)', coords_str)
    if len(numbers) < 4:
        # A short list would crash callers that unpack four coordinates.
        print("Error parsing coordinates: expected 4 numbers, got", len(numbers))
        return None
    return [float(num) for num in numbers[:4]]

def generate_masks(segmentation_model, image_path, bboxes):
    """
    Segment each box with a box-promptable predictor and return the union of the masks.

    Args:
        segmentation_model: Predictor exposing ``set_image(image)`` and
            ``predict(box=...) -> (masks, scores, _)``, presumably a
            SAM2ImagePredictor.
        image_path (str): Image to segment; loaded as RGB with PIL.
        bboxes (list): Boxes as [x_min, y_min, x_max, y_max] in pixel coordinates.

    Returns:
        numpy.ndarray: bool array of shape (image height, image width). It is
        all False when ``bboxes`` is empty. If segmentation raises, the error
        is printed and the masks accumulated so far are returned (best effort).
    """
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    mask_all = np.zeros((image.height, image.width), dtype=bool)
    if not bboxes:
        return mask_all

    try:
        segmentation_model.set_image(image)
        for bbox in bboxes:
            masks, scores, _ = segmentation_model.predict(box=bbox)
            # predict returns several candidate masks; keep the highest-scoring one.
            best_mask = masks[int(np.argmax(scores))].astype(bool)
            mask_all = np.logical_or(mask_all, best_mask)
    except Exception as e:
        # Deliberately best-effort: report and return what has been accumulated.
        print(f"Error generating masks: {e}")
    return mask_all


def _sanitize_refined_box(box, width, height):
    """Clamp a refined box into the image and order its corners; return None if unusable."""
    if not box or len(box) < 4:
        return None
    x_a, y_a, x_b, y_b = box[:4]
    x_min, x_max = sorted((min(max(x_a, 0), width), min(max(x_b, 0), width)))
    y_min, y_max = sorted((min(max(y_a, 0), height), min(max(y_b, 0), height)))
    if x_max <= x_min or y_max <= y_min:
        return None
    return [x_min, y_min, x_max, y_max]


def seg_agent_qwenvl(qwen_model, processor, segmentation_model, image_path, query):
    """
    Segment the object described by ``query`` using a VLM for boxes and a predictor for masks.

    Steps:
      1. Write augmented views (original, horizontal flip, half size) and ask
         the VLM for a box on each.
      2. Map the boxes back to original-image pixels, append their mean and
         drop near-duplicates.
      3. Draw the numbered candidates on the image and ask the VLM to pick one.
      4. Ask the VLM to refine the chosen box.
      5. Segment with the refined box (or the chosen one if refinement is unusable).

    Args:
        qwen_model: Vision-language model used by generate_response.
        processor: Its processor.
        segmentation_model: Box-promptable predictor used by generate_masks.
        image_path (str): Input image.
        query (str): Description of the target object.

    Returns:
        numpy.ndarray: bool mask from generate_masks. It is all False when no
        candidate box could be obtained.
    """
    with Image.open(image_path) as img:
        width, height = img.size

    # Step 1: candidate boxes from the augmented views.
    img_path_list = new_generate_augment_imgs(image_path)
    if not img_path_list:
        return generate_masks(segmentation_model, image_path, bboxes=[])

    prompt = build_generation_prompt(query)
    coords_list = []
    for aug_path in img_path_list:
        print(aug_path)
        response = generate_response(aug_path, prompt, qwen_model, processor)
        coords_list.append(extract_coords_from_response(response))

    # Step 2: map back, then keep only well-formed boxes before deduplicating.
    mapped = process_coords(coords_list, width, height)
    mapped = [c for c in mapped if c and len(c) == 4]
    filtered_coords_list = remove_duplicate_bboxes(mapped)
    if not filtered_coords_list:
        print("No bounding box could be parsed from the model responses.")
        return generate_masks(segmentation_model, image_path, bboxes=[])

    # Step 3: set-of-marks selection.
    bbox_img_path = draw_bboxes_on_image(image_path, filtered_coords_list)
    select_prompt = build_select_prompt(query, filtered_coords_list)
    response = generate_response(bbox_img_path, select_prompt, qwen_model, processor)
    best_box_index = extract_best_box_number(response)

    # The answer may be missing or out of range; "0" would otherwise silently
    # index the last box. Fall back to the first candidate.
    choice = int(best_box_index) - 1 if best_box_index is not None else 0
    if not 0 <= choice < len(filtered_coords_list):
        choice = 0
    best_box_coord = filtered_coords_list[choice]
    best_box_img_path = draw_bbox_on_image(image_path, best_box_coord, 'with_selected_box')

    # Step 4: refinement.
    prompt = build_optimize_prompt(query, best_box_coord)
    response = generate_response(best_box_img_path, prompt, qwen_model, processor)

    # Drop this function's references so the VLM can be collected before the
    # segmentation step if the caller holds no other reference to it.
    del qwen_model
    del processor
    torch.cuda.empty_cache()

    refined_coord = _sanitize_refined_box(extract_optimized_box(response), width, height)
    if not refined_coord:
        refined_coord = best_box_coord

    draw_bbox_on_image(image_path, refined_coord, 'with_refined_box')

    # Step 5: segmentation.
    return generate_masks(segmentation_model, image_path, bboxes=[refined_coord])


def main():
    """Placeholder entry point: prints a greeting only; the pipeline itself is seg_agent_qwenvl()."""
    print('hello')


if __name__ == '__main__':
    main()

    